{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Language Models as Zero-Shot Planners:<br>Extracting Actionable Knowledge for Embodied Agents\n",
    "\n",
    "This is the official demo code for our [Language Models as Zero-Shot Planners](https://huangwl18.github.io/language-planner/) paper. The code demonstrates how large language models, such as GPT-3 and Codex, can generate action plans for complex human activities (e.g. \"make breakfast\") without any further training. It works with any available language model from the [OpenAI API](https://openai.com/api/) or [Huggingface Transformers](https://huggingface.co/docs/transformers/index) through a common interface.\n",
    "\n",
    "**Note:**\n",
    "- We observe that the best results are obtained with larger language models. If you cannot run [Huggingface Transformers](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads) models locally or on Google Colab due to memory constraints, we recommend registering for an [OpenAI API](https://openai.com/api/) account and using GPT-3 or Codex (as of 01/2022, new accounts receive $18 in free credits, and the Codex series is free once you are [admitted from the waitlist](https://share.hsforms.com/1GzaACuXwSsmLKPfmphF_1w4sk30?)).\n",
    "- Language models are highly sensitive to sampling hyperparameters, so you may need to tune them for each model to obtain the best results.\n",
    "- The code uses the list of available actions supported in [VirtualHome 1.0](https://github.com/xavierpuigf/virtualhome/tree/v1.0.0)'s [Evolving Graph Simulator](https://github.com/xavierpuigf/virtualhome/tree/v1.0.0/simulation). The available actions are stored in [`available_actions.json`](https://github.com/huangwl18/language-planner/blob/master/src/available_actions.json). The actions should support a large variety of household tasks. However, you may modify or replace this file if you're interested in a different set of actions or a different domain of tasks (beyond the household domain).\n",
    "- A subset of the [manually-annotated examples](http://virtual-home.org/release/programs/programs_processed_precond_nograb_morepreconds.zip) originally collected by the [VirtualHome paper](https://arxiv.org/pdf/1806.07011.pdf) is used as available examples in the prompt. They are transformed to natural language format and stored in [`available_examples.json`](https://github.com/huangwl18/language-planner/blob/master/src/available_examples.json). Feel free to change this file for a different set of available examples."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setup\n",
    "If you're on Colab, do the following:\n",
    "1. Uncomment and run the following cell to set up the repo and install dependencies.\n",
    "2. Enable GPU support (Runtime -> Change Runtime Type -> Hardware Accelerator -> set to GPU)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !git clone https://github.com/huangwl18/language-planner\n",
    "# %cd language-planner\n",
    "# !pip install -r requirements.txt\n",
    "# %cd src"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Import packages and specify which GPU to use."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import openai\n",
    "import numpy as np\n",
    "import torch\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from sentence_transformers import util as st_utils\n",
    "import json\n",
    "\n",
    "GPU = 0\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.set_device(GPU)\n",
    "OPENAI_KEY = None  # replace this with your OpenAI API key, if you choose to use OpenAI API"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define hyperparameters for plan generation\n",
    "\n",
    "Select the source to be used (**OpenAI API** or **Huggingface Transformers**) and the two LMs to be used (**Planning LM** for plan generation, **Translation LM** for action/example matching). Then define generation hyperparameters.\n",
    "\n",
    "Available language models for **Planning LM** can be found from:\n",
    "- [Huggingface Transformers](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads)\n",
    "- [OpenAI API](https://beta.openai.com/docs/engines) (you need to paste the API key from your OpenAI account into `OPENAI_KEY` in the cell above; it is assigned to `openai.api_key` below)\n",
    "\n",
    "Available language models for **Translation LM** can be found from:\n",
    "- [Sentence Transformers](https://huggingface.co/sentence-transformers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "source = 'huggingface'  # select from ['openai', 'huggingface']\n",
    "planning_lm_id = 'gpt2-large'  # see comments above for all options\n",
    "translation_lm_id = 'stsb-roberta-large'  # see comments above for all options\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "MAX_STEPS = 20  # maximum number of steps to be generated\n",
    "CUTOFF_THRESHOLD = 0.8  # early stopping threshold based on matching score and likelihood score\n",
    "P = 0.5  # hyperparameter for early stopping heuristic to detect whether Planning LM believes the plan is finished\n",
    "BETA = 0.3  # weighting coefficient used to rank generated samples\n",
    "if source == 'openai':\n",
    "    openai.api_key = OPENAI_KEY\n",
    "    sampling_params = \\\n",
    "            {\n",
    "                \"max_tokens\": 10,\n",
    "                \"temperature\": 0.6,\n",
    "                \"top_p\": 0.9,\n",
    "                \"n\": 10,\n",
    "                \"logprobs\": 1,\n",
    "                \"presence_penalty\": 0.5,\n",
    "                \"frequency_penalty\": 0.3,\n",
    "                \"stop\": '\\n'\n",
    "            }\n",
    "elif source == 'huggingface':\n",
    "    sampling_params = \\\n",
    "            {\n",
    "                \"max_tokens\": 10,\n",
    "                \"temperature\": 0.1,\n",
    "                \"top_p\": 0.9,\n",
    "                \"num_return_sequences\": 10,\n",
    "                \"repetition_penalty\": 1.2,\n",
    "                'use_cache': True,\n",
    "                'output_scores': True,\n",
    "                'return_dict_in_generate': True,\n",
    "                'do_sample': True,\n",
    "            }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Planning LM Initialization\n",
    "Initialize **Planning LM** from either **OpenAI API** or **Huggingface Transformers**. Abstract away the underlying source by creating a generator function with a common interface."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lm_engine(source, planning_lm_id, device):\n",
    "    if source == 'huggingface':\n",
    "        from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "        tokenizer = AutoTokenizer.from_pretrained(planning_lm_id)\n",
    "        model = AutoModelForCausalLM.from_pretrained(planning_lm_id, pad_token_id=tokenizer.eos_token_id).to(device)\n",
    "\n",
    "    def _generate(prompt, sampling_params):\n",
    "        if source == 'openai':\n",
    "            response = openai.Completion.create(engine=planning_lm_id, prompt=prompt, **sampling_params)\n",
    "            generated_samples = [response['choices'][i]['text'] for i in range(sampling_params['n'])]\n",
    "            # calculate mean log prob across tokens\n",
    "            mean_log_probs = [np.mean(response['choices'][i]['logprobs']['token_logprobs']) for i in range(sampling_params['n'])]\n",
    "        elif source == 'huggingface':\n",
    "            input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(device)\n",
    "            prompt_len = input_ids.shape[-1]\n",
    "            output_dict = model.generate(input_ids, max_length=prompt_len + sampling_params['max_tokens'], **sampling_params)\n",
    "            # discard the prompt (only take the generated text)\n",
    "            generated_samples = tokenizer.batch_decode(output_dict.sequences[:, prompt_len:])\n",
    "            # calculate per-token logprob\n",
    "            vocab_log_probs = torch.stack(output_dict.scores, dim=1).log_softmax(-1)  # [n, length, vocab_size]\n",
    "            token_log_probs = torch.gather(vocab_log_probs, 2, output_dict.sequences[:, prompt_len:, None]).squeeze(-1).tolist()  # [n, length]\n",
    "            # truncate each sample if it contains '\\n' (the current step is finished)\n",
    "            # e.g. 'open fridge\\n<|endoftext|>' -> 'open fridge'\n",
    "            for i, sample in enumerate(generated_samples):\n",
    "                stop_idx = sample.index('\\n') if '\\n' in sample else None\n",
    "                generated_samples[i] = sample[:stop_idx]\n",
    "                token_log_probs[i] = token_log_probs[i][:stop_idx]\n",
    "            # calculate mean log prob across tokens\n",
    "            mean_log_probs = [np.mean(token_log_probs[i]) for i in range(sampling_params['num_return_sequences'])]\n",
    "        generated_samples = [sample.strip().lower() for sample in generated_samples]\n",
    "        return generated_samples, mean_log_probs\n",
    "\n",
    "    return _generate\n",
    "\n",
    "generator = lm_engine(source, planning_lm_id, device)"
   ]
  },
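  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Huggingface branch above turns per-step scores into one mean log probability per sample. The snippet below is a minimal sketch of that computation on toy data, not the notebook's implementation: NumPy stands in for torch, and the function name `mean_token_log_probs` and the toy logits are assumptions made for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def mean_token_log_probs(scores, token_ids):\n",
    "    # scores: [n, length, vocab_size] logits; token_ids: [n, length] sampled token ids\n",
    "    shifted = scores - scores.max(axis=-1, keepdims=True)  # subtract the max for numerical stability\n",
    "    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))  # log-softmax over the vocabulary\n",
    "    # pick the log prob of each sampled token -> [n, length]\n",
    "    token_log_probs = np.take_along_axis(log_probs, token_ids[..., None], axis=-1).squeeze(-1)\n",
    "    return token_log_probs.mean(axis=-1)  # one mean log prob per sample\n",
    "\n",
    "# toy logits (made up): 1 sample, 2 generated tokens, vocabulary of 2 tokens\n",
    "scores = np.zeros((1, 2, 2))\n",
    "token_ids = np.array([[0, 1]])\n",
    "mean_lp = mean_token_log_probs(scores, token_ids)\n",
    "print(mean_lp)\n",
    "```\n",
    "\n",
    "With uniform logits over two tokens, every token has probability 0.5, so the mean log probability equals `log(0.5)`."
   ]
  },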
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Translation LM Initialization\n",
    "Initialize **Translation LM** and create embeddings for all available actions (for action translation) and for the task names of all available examples (for finding the most relevant example)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# initialize Translation LM\n",
    "translation_lm = SentenceTransformer(translation_lm_id).to(device)\n",
    "\n",
    "# create action embeddings using Translated LM\n",
    "with open('available_actions.json', 'r') as f:\n",
    "    action_list = json.load(f)\n",
    "action_list_embedding = translation_lm.encode(action_list, batch_size=512, convert_to_tensor=True, device=device)  # lower batch_size if limited by GPU memory\n",
    "\n",
    "# create example task embeddings using Translated LM\n",
    "with open('available_examples.json', 'r') as f:\n",
    "    available_examples = json.load(f)\n",
    "example_task_list = [example.split('\\n')[0] for example in available_examples]  # first line contains the task name\n",
    "example_task_embedding = translation_lm.encode(example_task_list, batch_size=512, convert_to_tensor=True, device=device)  # lower batch_size if limited by GPU memory\n",
    "\n",
    "# helper function for finding similar sentence in a corpus given a query\n",
    "def find_most_similar(query_str, corpus_embedding):\n",
    "    query_embedding = translation_lm.encode(query_str, convert_to_tensor=True, device=device)\n",
    "    # calculate cosine similarity against each candidate sentence in the corpus\n",
    "    cos_scores = st_utils.pytorch_cos_sim(query_embedding, corpus_embedding)[0].detach().cpu().numpy()\n",
    "    # retrieve high-ranked index and similarity score\n",
    "    most_similar_idx, matching_score = np.argmax(cos_scores), np.max(cos_scores)\n",
    "    return most_similar_idx, matching_score"
   ]
  },
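  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the matching step concrete, the snippet below is a minimal sketch of the same nearest-neighbor lookup using plain NumPy. It is an illustration under stated assumptions: `find_most_similar_np` is a hypothetical stand-in for `find_most_similar`, and the 2-D embeddings are made up rather than produced by the **Translation LM**.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def find_most_similar_np(query_embedding, corpus_embedding):\n",
    "    # normalize both sides so that the dot product equals cosine similarity\n",
    "    query = query_embedding / np.linalg.norm(query_embedding)\n",
    "    corpus = corpus_embedding / np.linalg.norm(corpus_embedding, axis=1, keepdims=True)\n",
    "    cos_scores = corpus @ query  # one score per corpus row\n",
    "    most_similar_idx = int(np.argmax(cos_scores))\n",
    "    return most_similar_idx, float(cos_scores[most_similar_idx])\n",
    "\n",
    "# toy 2-D embeddings (made up for illustration)\n",
    "corpus_embedding = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])\n",
    "query_embedding = np.array([2.0, 2.0])\n",
    "idx, score = find_most_similar_np(query_embedding, corpus_embedding)\n",
    "print(idx, round(score, 3))\n",
    "```\n",
    "\n",
    "The query points in the same direction as the third corpus row, so that row is returned with a cosine similarity of 1 even though the vector lengths differ."
   ]
  },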
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Autoregressive Plan Generation\n",
    "Generate an action plan autoregressively by mapping each generated action step to the closest available action.\n",
    "\n",
    "The algorithm is summarized as follows:\n",
    "1. Extract the most relevant human-annotated example by matching the task names (`Make breakfast` <-> `Make toast`)\n",
    "2. Prompt the **Planning LM** with `[One Example Plan] + [Query Task] (+ [Previously Generated Actions])`\n",
    "3. Sample **Planning LM** `n` times for single-step text output.\n",
    "4. Rank the samples by `[Matching Score] + BETA * [Mean Log Prob]`, where\n",
    "    1. `[Matching Score]` is the cosine similarity to the closest allowed action as determined by **Translation LM**\n",
    "    2. `[Mean Log Prob]` is the mean log probability across generated tokens given by **Planning LM**\n",
    "5. Append the closest allowed action of the highest-ranked sample to the prompt from Step 2\n",
    "6. Return to Step 3 and repeat until an early-stopping condition is met or `MAX_STEPS` is reached"
   ]
  },
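  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ranking in Step 4 can be illustrated in isolation. The snippet below is a minimal sketch, not the notebook's implementation: the helper `rank_samples` and the candidate tuple format are hypothetical, and the matching scores and mean log probabilities are made-up numbers rather than model outputs.\n",
    "\n",
    "```python\n",
    "BETA = 0.3  # same weighting coefficient as in the hyperparameter cell\n",
    "\n",
    "def rank_samples(candidates, beta=BETA):\n",
    "    # candidates: list of (action, matching_score, mean_log_prob) tuples (hypothetical format)\n",
    "    # sort by matching_score + beta * mean_log_prob, best candidate first\n",
    "    return sorted(candidates, key=lambda c: c[1] + beta * c[2], reverse=True)\n",
    "\n",
    "# made-up candidates for illustration only\n",
    "candidates = [\n",
    "    ('walk to kitchen', 0.95, -0.40),\n",
    "    ('open fridge', 0.90, -0.10),\n",
    "    ('find bread', 0.70, -0.05),\n",
    "]\n",
    "ranked = rank_samples(candidates)\n",
    "best_action = ranked[0][0]\n",
    "print(best_action)\n",
    "```\n",
    "\n",
    "With these numbers, a slightly worse match with a higher likelihood (`open fridge`) outranks the closest match (`walk to kitchen`), which is the trade-off that `BETA` controls."
   ]
  },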
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define query task\n",
    "task = 'Make breakfast'\n",
    "# find most relevant example\n",
    "example_idx, _ = find_most_similar(task, example_task_embedding)\n",
    "example = available_examples[example_idx]\n",
    "# construct initial prompt\n",
    "curr_prompt = f'{example}\\n\\nTask: {task}'\n",
    "# print example and query task\n",
    "print('-'*10 + ' GIVEN EXAMPLE ' + '-'*10)\n",
    "print(example)\n",
    "print('-'*10 + ' EXAMPLE END ' + '-'*10)\n",
    "print(f'\\nTask: {task}')\n",
    "for step in range(1, MAX_STEPS + 1):\n",
    "    best_overall_score = -np.inf\n",
    "    # query Planning LM for single-step action candidates\n",
    "    samples, log_probs = generator(curr_prompt + f'\\nStep {step}:', sampling_params)\n",
    "    for sample, log_prob in zip(samples, log_probs):\n",
    "        most_similar_idx, matching_score = find_most_similar(sample, action_list_embedding)\n",
    "        # rank each sample by its similarity score and likelihood score\n",
    "        overall_score = matching_score + BETA * log_prob\n",
    "        translated_action = action_list[most_similar_idx]\n",
    "        # heuristic for penalizing generating the same action as the last action\n",
    "        if step > 1 and translated_action == previous_action:\n",
    "            overall_score -= 0.5\n",
    "        # find the translated action with highest overall score\n",
    "        if overall_score > best_overall_score:\n",
    "            best_overall_score = overall_score\n",
    "            best_action = translated_action\n",
    "\n",
    "    # terminate early when either the following is true:\n",
    "    # 1. top P*100% of samples are all 0-length (ranked by log prob)\n",
    "    # 2. overall score is below CUTOFF_THRESHOLD\n",
    "    # else: autoregressive generation based on previously translated action\n",
    "    top_samples_ids = np.argsort(log_probs)[-int(P * len(samples)):]\n",
    "    are_zero_length = all([len(samples[i]) == 0 for i in top_samples_ids])\n",
    "    below_threshold = best_overall_score < CUTOFF_THRESHOLD\n",
    "    if are_zero_length:\n",
    "        print(f'\\n[Terminating early because top {P*100}% of samples are all 0-length]')\n",
    "        break\n",
    "    elif below_threshold:\n",
    "        print(f'\\n[Terminating early because best overall score is lower than CUTOFF_THRESHOLD ({best_overall_score} < {CUTOFF_THRESHOLD})]')\n",
    "        break\n",
    "    else:\n",
    "        previous_action = best_action\n",
    "        formatted_action = (best_action[0].upper() + best_action[1:]).replace('_', ' ') # 'open_fridge' -> 'Open fridge'\n",
    "        curr_prompt += f'\\nStep {step}: {formatted_action}'\n",
    "        print(f'Step {step}: {formatted_action}')"
   ]
  },
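  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The early-stopping logic in the loop above can also be read as a standalone check. The snippet below is a minimal sketch that mirrors the two conditions under stated assumptions: the function name `should_terminate` and the sample values are hypothetical, and the defaults copy `P` and `CUTOFF_THRESHOLD` from the hyperparameter cell.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def should_terminate(samples, log_probs, best_overall_score, p=0.5, cutoff_threshold=0.8):\n",
    "    # indices of the top p fraction of samples, ranked by mean log prob\n",
    "    top_ids = np.argsort(log_probs)[-int(p * len(samples)):]\n",
    "    # condition 1: the most likely samples are all empty, so the plan looks finished\n",
    "    are_zero_length = all(len(samples[i]) == 0 for i in top_ids)\n",
    "    # condition 2: even the best candidate scores below the cutoff\n",
    "    below_threshold = best_overall_score < cutoff_threshold\n",
    "    return are_zero_length or below_threshold\n",
    "\n",
    "# made-up values for illustration only\n",
    "log_probs = [-0.1, -0.2, -1.0, -2.0]\n",
    "finished = should_terminate(['', '', 'open fridge', 'find bread'], log_probs, 0.9)\n",
    "ongoing = should_terminate(['open fridge', '', 'find bread', 'walk'], log_probs, 0.9)\n",
    "low_score = should_terminate(['open fridge', '', 'find bread', 'walk'], log_probs, 0.5)\n",
    "print(finished, ongoing, low_score)\n",
    "```\n",
    "\n",
    "The notebook keeps the two conditions separate so that the printed message can say which one ended the plan."
   ]
  },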
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "232f5b83f314fecbea34be6d21efbdd7d4bed02a72211b6c9ab7feb2cb54a9b0"
  },
  "kernelspec": {
   "display_name": "Python 3.6.13 64-bit ('vh-1.0': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
